Triton入门精选Puzzles
原版:
srush/Triton-Puzzles: Puzzles for learning Triton
缩减版:
参考答案:
alexzhang13/Triton-Puzzles-Solutions: Personal solutions to the Triton Puzzles
Puzzle 2: 涉及block的常量相加
B0小于x的长度N0, 需要启动多个block; 当N0不是B0的整数倍时, 最后一个block的偏移会超出N0, 为了不数组越界, load和store都需要添加mask
@triton.jit
def add_mask2_kernel(x_ptr, z_ptr, N0, B0: tl.constexpr):
    """Apply ``add2_spec`` to one B0-sized block of x and write the result to z.

    Program instance ``pid`` covers flat indices [pid * B0, pid * B0 + B0).
    The final block may extend past N0, so both the load and the store are
    masked to indices < N0.
    """
    block_start = tl.program_id(0) * B0
    offsets = block_start + tl.arange(0, B0)
    in_bounds = offsets < N0
    # Lanes outside [0, N0) read 0 instead of dereferencing past the end of x.
    values = tl.load(x_ptr + offsets, mask=in_bounds, other=0)
    # NOTE(review): add2_spec is defined outside this view; calling a plain
    # Python function inside @triton.jit presumably relies on interpreter mode
    # (TRITON_INTERPRET=1) — confirm.
    result = add2_spec(values)
    tl.store(z_ptr + offsets, result, mask=in_bounds)
Puzzle 4: 涉及block的向量外加(outer add): z[j, i] = x[i] + y[j]
@triton.jit
def add_vec_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    """Blocked outer sum: z[j, i] = x[i] + y[j].

    z is addressed as a row-major array with N0 columns (flat index
    j * N0 + i). Program (pid0, pid1) computes the B1 x B0 tile covering
    columns [pid0 * B0, pid0 * B0 + B0) and rows [pid1 * B1, pid1 * B1 + B1).
    Loads and the store are masked to i < N0 and j < N1.
    """
    block_id_x = tl.program_id(0)
    block_id_y = tl.program_id(1)
    off_x = block_id_x * B0 + tl.arange(0, B0)
    off_y = block_id_y * B1 + tl.arange(0, B1)
    mask_x = off_x < N0
    mask_y = off_y < N1
    x = tl.load(x_ptr + off_x, mask=mask_x, other=0)
    y = tl.load(y_ptr + off_y, mask=mask_y, other=0)
    # x varies along columns (row vector), y along rows (column vector);
    # broadcasting yields a (B1, B0) tile.
    z = x[None, :] + y[:, None]
    off_z = off_x[None, :] + off_y[:, None] * N0
    mask_z = mask_x[None, :] & mask_y[:, None]
    # Store ONLY through the mask. In edge tiles, lanes with off_x >= N0 map to
    # j * N0 + i, which wraps into row j + 1 and would clobber valid output;
    # lanes with off_y >= N1 would land past the last row.
    tl.store(z_ptr + off_z, z, mask=mask_z)
load和store一定要有mask!!! 边界block的偏移会超出N0/N1, 不加mask的store会越界写, 甚至把数据写进下一行覆盖正确结果.
Puzzle 5: 外乘 + ReLU: z[j, i] = relu(x[i] * y[j])
@triton.jit
def mul_relu_block_kernel(
    x_ptr, y_ptr, z_ptr, N0, N1, B0: tl.constexpr, B1: tl.constexpr
):
    """Blocked fused outer product + ReLU: z[j, i] = relu(x[i] * y[j]).

    z is addressed as a row-major array with N0 columns (flat index
    j * N0 + i). Program (pid0, pid1) handles a B1 x B0 tile; loads and the
    store are masked to i < N0 and j < N1.
    """
    cols = tl.program_id(0) * B0 + tl.arange(0, B0)
    rows = tl.program_id(1) * B1 + tl.arange(0, B1)
    col_ok = cols < N0
    row_ok = rows < N1

    x_vals = tl.load(x_ptr + cols, mask=col_ok, other=0)
    y_vals = tl.load(y_ptr + rows, mask=row_ok, other=0)

    # (1, B0) * (B1, 1) broadcasts to a (B1, B0) tile.
    product = x_vals[None, :] * y_vals[:, None]
    # ReLU: keep non-negative entries, replace the rest with 0.
    activated = tl.where(product >= 0, product, 0)

    flat_idx = cols[None, :] + rows[:, None] * N0
    tile_ok = col_ok[None, :] & row_ok[:, None]
    tl.store(z_ptr + flat_idx, activated, mask=tile_ok)